MRI data preprocessed here 📒, which is a processed version of this dataset 💼
Description of the dataset¶
This brain tumor dataset contains 3064 T1-weighted contrast-enhanced images from 233 patients with three kinds of brain tumor: meningioma (708 slices), glioma (1426 slices), and pituitary tumor (930 slices). Due to the file size limit of the repository, we split the whole dataset into 4 subsets and archive them in 4 .zip files, each containing 766 slices. The 5-fold cross-validation indices are also provided.
This data is organized in matlab data format (.mat file). Each file stores a struct containing the following fields for an image:
- cjdata.label: 1 for meningioma, 2 for glioma, 3 for pituitary tumor 3️⃣
- cjdata.PID: patient ID
- cjdata.image: image data
- cjdata.tumorBorder: a vector storing the coordinates of discrete points on tumor border.
For example, [x1, y1, x2, y2, ...] in which x1, y1 are planar coordinates on the tumor border. It was generated by manually delineating the tumor border, so it can be used to generate a binary image of the tumor mask.
- cjdata.tumorMask: a binary image with 1s indicating tumor region
This data was used in the following paper:
- Cheng, Jun, et al. "Enhanced Performance of Brain Tumor Classification via Tumor Region Augmentation and Partition." PloS one 10.10 (2015).
- Cheng, Jun, et al. "Retrieval of Brain Tumors by Adaptive Spatial Pooling and Fisher Vector Representation." PloS one 11.6 (2016). Matlab source codes are available on github https://github.com/chengjun583/brainTumorRetrieval
from forgebox.imports import *
from tqdm.notebook import tqdm
import pytorch_lightning as pl
import plotly.express as px
import plotly.graph_objects as go
from ipywidgets import interact
# Dataset root and sub-directories.
DATA = Path("/GCI/brain_mri/")
MATS = DATA/"mats"        # original matlab .mat files (see dataset description above)
NUMPYS = DATA/"npy"       # converted numpy arrays
WEIGHTS = DATA/"weights"  # model checkpoints written during training
WEIGHTS.mkdir(exist_ok = True)
A metadata pandas dataframe holds the info about each image
Column information¶
- pid: patient id
- img: location of the image numpy
- label: 1 for meningioma, 2 for glioma, 3 for pituitary tumor
- shape: shape of the image, here we use only the 512x512
- img_id: original mat file id of the image
# Load slice-level metadata, keep only the 512x512 images,
# and order them by their original .mat file id.
df = pd.read_csv(DATA / "meta.csv")
df["img_id"] = df.img.apply(lambda p: int(Path(p).name.split(".")[0]))
df = (
    df.query("shape=='512_512'")
    .sort_values(by=["img_id"])
    .reset_index(drop=True)
)
df.sample(10)
# label counts (forgebox value-count helper)
df.vc("label")
def vis_patient(pid):
    """Interactively browse all slices of one patient, with the tumor mask
    overlaid as an RGB composite (R: mask, G: image minus mask, B: image).

    - pid: patient id as stored in df.pid (string)
    """
    sub_df = df.query(f"pid=='{pid}'").sort_values(by="img_id")
    # stack slices into (n_slices, H, W); /1000 rescales raw intensities
    # toward [0, 1] — presumably matches this dataset's value range,
    # TODO confirm against the .mat source data
    img_arr = np.stack(list(np.load(i) for i in sub_df.img))\
        .astype(np.float32)/1000
    mask_arr = np.stack(list(np.load(i) for i in sub_df["mask"]))\
        .astype(np.float32)
    @interact
    def show_mri(i = (1,len(img_arr))):
        # slider is 1-based (inclusive bounds), arrays are 0-based
        print(list(sub_df.img)[i-1])
        rgb_arr = np.stack([
            mask_arr[i-1],
            np.clip(img_arr[i-1]-mask_arr[i-1],0.,1.),
            img_arr[i-1],
        ], axis=-1)
        # rgb_arr = img_arr[i-1].astype(np.float32)
        # print(rgb_arr[200:230,200:230])
        display(plt.imshow(rgb_arr))
vis_patient('100360')
class mri_data(Dataset):
    """Slice-level MRI dataset backed by a metadata dataframe.

    Each row of ``df`` points to a .npy image on disk; an item is
    (image with a leading channel axis, scaled to [0, 1] by its max,
    zero-based class label).
    """

    def __init__(self, df: pd.DataFrame):
        super().__init__()
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __repr__(self):
        n_patients = len(self.df.pid.unique())
        return f"MRI Dataset:\n\t{n_patients} patients, {len(self)} slices"

    def __getitem__(self, idx):
        row = dict(self.df.loc[idx])
        image = np.load(row["img"])
        image = image / image.max()
        # labels are stored 1..3 on disk -> shift to 0..2 for the loss
        return image[None, ...], row["label"] - 1
def split_by(
    df: pd.DataFrame,
    col: str,
    val_ratio: float = .2,
    seed=None,
):
    """
    Split rows into train/validation frames by the unique values of ``col``,
    so all rows sharing one value land on the same side (e.g. a patient-level
    split on "pid" prevents slices of one patient leaking across the split).

    - col: column whose unique values define the split groups
    - val_ratio: fraction of unique values assigned to validation
    - seed: optional int; when given, makes both the group choice and the
      training-frame shuffle reproducible (default None keeps the previous
      non-reproducible behavior)

    Returns (train_df, val_df); train_df is shuffled, both re-indexed.
    """
    uniques = df[col].unique()
    # fall back to the global numpy RNG when no seed is given
    rng = np.random if seed is None else np.random.default_rng(seed)
    validation_ids = rng.choice(
        uniques, size=int(len(uniques) * val_ratio), replace=False)
    val_slice = df[col].isin(validation_ids)
    train_df = df[~val_slice].sample(frac=1., random_state=seed).reset_index(drop=True)
    return train_df, df[val_slice].reset_index(drop=True)
# Patient-level split: every slice of a patient stays on one side.
train_df, val_df = split_by(df, "pid")
total_ds, train_ds, val_ds = (mri_data(frame) for frame in (df, train_df, val_df))
train_ds, val_ds
# peek at one training sample
x, y = train_ds[5]
x.shape, y
Mean & standard deviation of the entire dataset; we need them for the preprocessing layer normalization
# Collect per-slice (mean, std) over the whole dataset.
all_x = []
for sample_idx in tqdm(range(len(total_ds))):
    x, _label = total_ds[sample_idx]
    all_x.append(np.array([x.mean(), x.std()]))
all_arr = np.array(all_x)
# dataset-level statistics = average of the per-slice statistics
x_mean, x_std = all_arr.mean(0)
x_mean, x_std
# range of the per-slice means and stds
all_arr[:, 0].min(), all_arr[:, 0].max(), all_arr[:, 1].min(), all_arr[:, 1].max()
Experiments with EfficientNet classification
from efficientnet_pytorch import EfficientNet
from efficientnet_pytorch.utils import Conv2dStaticSamePadding
# EfficientNet-B5 backbone with a 3-way classification head.
model = EfficientNet.from_pretrained("efficientnet-b5", num_classes=3)
# Swap the stem so the network accepts 1-channel (grayscale MRI) input
# instead of RGB; note the pretrained stem weights are discarded here.
model._conv_stem = Conv2dStaticSamePadding(
    1, 48, kernel_size=(3, 3), stride=(2, 2), bias=False, image_size=512
)
# smoke test: forward one sample, expect (1, 3) logits
model(torch.FloatTensor(x)[None,...]).shape
class PlData(pl.LightningDataModule):
    """LightningDataModule wrapping the train/validation MRI datasets."""

    def __init__(self, train_df, val_df, bs):
        super().__init__()
        self.bs = bs
        self.train_df = train_df
        self.val_df = val_df
        self.train_ds = mri_data(self.train_df)
        self.val_ds = mri_data(self.val_df)

    def train_dataloader(self):
        """Shuffled training loader."""
        return DataLoader(
            self.train_ds, shuffle=True, num_workers=8, batch_size=self.bs,
        )

    def val_dataloader(self):
        """Validation loader; batch size is twice the training batch size."""
        return DataLoader(
            self.val_ds, shuffle=False, num_workers=8, batch_size=self.bs * 2,
        )
class PlMRIModel(pl.LightningModule):
    """Lightning wrapper that trains a classification backbone with
    cross-entropy loss and logs loss/accuracy for each phase."""

    def __init__(self, base_model):
        super().__init__()
        self.base = base_model
        self.softmax = nn.Softmax(dim=-1)
        self.crit = nn.CrossEntropyLoss()
        self.accuracy_f = pl.metrics.Accuracy()

    def forward(self, x):
        return self.base(x)

    def configure_optimizers(self):
        # AdamW over the backbone parameters only
        return torch.optim.AdamW(self.base.parameters(), lr=1e-4)

    def calc_all_metrics(self, y_, y, is_train):
        """Log accuracy under 'train_acc' or 'val_acc'."""
        phase = "train" if is_train else "val"
        self.log(f'{phase}_acc', self.accuracy_f(self.softmax(y_), y))

    def _shared_step(self, batch, is_train):
        # one step shared by training and validation
        x, y = batch
        x, y = x.float(), y.long()
        y_ = self(x)
        loss = self.crit(y_, y)
        self.log('train_loss' if is_train else 'val_loss', loss)
        self.calc_all_metrics(y_, y, is_train)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, True)

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, False)
pl_data = PlData(train_df, val_df, bs=8)
pl_model = PlMRIModel(model)
# loggers
logger = pl.loggers.TensorBoardLogger("/GCI/tensorboard/brain_mri", name="cls")
# callbacks
# BUG FIX: EarlyStopping defaults to mode="min"; when monitoring an accuracy
# it must be mode="max", otherwise training stops as soon as val_acc stops
# DECREASING. (The checkpoint callback below already uses mode="max".)
early = pl.callbacks.EarlyStopping(monitor="val_acc", mode="max")
# keep the 3 best checkpoints by validation accuracy
saving = pl.callbacks.ModelCheckpoint(
    str(WEIGHTS / "cls_models"), monitor="val_acc", save_top_k=3, mode="max"
)
trainer = pl.Trainer(
    logger=logger,
    callbacks=[early, saving],
    checkpoint_callback=True,
    gpus=1,
    fast_dev_run=False,
)
trainer.fit(pl_model, pl_data)
Switch to evaluation mode
# Switch to inference mode (disables dropout, freezes batch-norm statistics).
pl_model = pl_model.eval()
An array to array pipeline
from forgebox.ftorch.cuda import CudaHandler
cu = CudaHandler()
# cu.idle() presumably selects the least-busy GPU and the second call
# returns the torch device — NOTE(review): confirm against forgebox docs.
dev = cu.idle()()
pl_model = pl_model.to(dev)
def pred(x: np.ndarray) -> np.ndarray:
    """
    Predict class probabilities from an image array.

    - x: a batch-shaped array, e.g. (1, 1, H, W)

    Returns the probability vector of the first (only) sample, ordered
    (meningioma, glioma, pituitary).

    NOTE(review): fixed the annotations — the function receives and returns
    numpy arrays, not a float; `np.array` is a constructor, not a type.
    """
    with torch.no_grad():
        return pl_model.softmax(
            pl_model(torch.FloatTensor(x).to(dev))).cpu().detach().numpy()[0]
@interact
def see_val(idx = (0, len(val_ds) - 1)):
    """
    Browse validation slices showing the predicted probabilities next to the
    one-hot label.

    BUG FIX: ipywidgets slider bounds are inclusive, so the upper bound must
    be len(val_ds) - 1; the original (0, len(val_ds)) raised at the maximum.
    """
    x, y = val_ds[idx]
    print(f"Prediction, {pred(x[None,:])}, label {np.eye(3)[int(y)]}")
    plt.imshow(x[0])
# Run the model over the whole validation set and collect a prediction table.
preds = []
labels = []
for i in tqdm(range(len(val_ds))):
    image, label = val_ds[i]
    preds.append(pred(image[None, :]))
    labels.append(label)
pred_arr = np.stack(preds)
pred_df = pd.DataFrame({
    "idx": range(len(val_ds)),
    "meningioma": pred_arr[:, 0],
    "glioma": pred_arr[:, 1],
    "pituitary": pred_arr[:, 2],
    "pred_idx": pred_arr.argmax(-1),
    "labels": labels,
})
from forgebox.images import widgets
def full_getter(cls, idx):
    """
    Like mri_data.__getitem__, but additionally return the tumor mask.

    Returns (image with a leading channel axis, scaled to [0, 1] by its max,
    raw mask array, zero-based label).
    """
    row = dict(cls.df.loc[idx])
    image = np.load(row["img"])
    mask = np.load(row["mask"])
    scaled = image / image.max()
    return scaled[None, ...], mask, row["label"] - 1
# Monkey-patch the richer accessor onto the dataset class so
# val_ds.full_getter(i) also yields the tumor mask.
mri_data.full_getter = full_getter
def create_img(i):
    """
    Render validation slice *i* as a side-by-side PIL image:
    left = plain slice, right = slice with the tumor mask tinted red.

    BUG FIX: the original did `(img_all*256).astype(np.byte)` — np.byte is a
    *signed* int8, so any pixel >= 0.5 overflowed (white became black/garbage).
    Scale by 255 and use uint8, which is what Image.fromarray expects for RGB.
    """
    x, mask, _label = val_ds.full_getter(i)
    img = x[0]
    mask = mask.astype(np.float32)
    plain = np.stack([np.zeros_like(img), img, img], axis=-1)
    overlay = np.stack([mask * .5, img, img], axis=-1)
    img_all = np.concatenate([plain, overlay], axis=1)
    return Image.fromarray(
        np.clip(img_all * 255, 0, 255).astype(np.uint8), mode="RGB")
# Class names indexed by zero-based label: 0 meningioma, 1 glioma, 2 pituitary.
cancer_types = ["meningioma","glioma", "pituitary"]
def view_top(cancer_type):
    """
    Display the 12 validation slices the model is most confident
    belong to ``cancer_type``, together with their prediction rows.
    """
    top_ = pred_df.sort_values(by=cancer_type, ascending=False).head(12)
    img_list = [create_img(i) for i in top_.idx]
    display(top_)
    widgets.view_images(*img_list, num_per_row=2)()
def view_top_error(cancer_type):
    """
    Display the 12 *misclassified* validation slices with the highest
    predicted probability for ``cancer_type``.
    """
    wrong = pred_df.query("pred_idx!=labels")
    top_ = wrong.sort_values(by=cancer_type, ascending=False).head(12)
    img_list = [create_img(i) for i in top_.idx]
    display(top_)
    widgets.view_images(*img_list, num_per_row=2)()
# Inspect the most confident predictions, then the worst errors,
# for every tumor type (same order as the original six calls).
for _viewer in (view_top, view_top_error):
    for _cancer in cancer_types:
        _ = _viewer(_cancer)